# GAIA_Project / tools.py
# Author: sarthak1311
# Last change: "updated file save paths" (commit e070b43)
import os
import math
import contextlib
import io
import traceback
from pathlib import Path
from typing import Literal
import pandas as pd
from groq import Groq
from langchain_core.tools import tool
from langchain_community.tools import ArxivQueryRun, WikipediaQueryRun
from langchain_community.utilities import ArxivAPIWrapper, WikipediaAPIWrapper
from langchain_community.document_loaders import TextLoader
from langchain_tavily import TavilySearch
from urllib.parse import parse_qs, urlparse
from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled
import yt_dlp
# for local testing
# from dotenv import load_dotenv
# load_dotenv()
# Public surface of this module: every name listed here is a LangChain tool
# defined below with the @tool decorator.
__all__ = [
    "calculator",
    "Web_Search",
    "Arxiv_Search",
    "Wikipedia_Search",
    "get_yt_video_info_metadata",
    "get_yt_video_transcript",
    "analyze_excel_file",
    "read_file",
    "analyze_image",
    "transcribe_audio_file",
    "save_file_temp",
    "execute_code_file",
]

# Web_Search authenticates with this key. Fail fast at import time rather than
# on the agent's first search call.
tavily_api_key = os.environ.get("TAVILY_API_KEY")
if tavily_api_key is None:
    raise ValueError("TAVILY_API_KEY is not set in environment variables")
# MATHEMATICAL TOOLS
@tool
def calculator(a:float, b:float, operation:Literal["add", "subtract", "multiply", "divide", "round_off", "power_exponent"]) -> float:
    """Use this to perform the following mathematical operations:
    "add", "subtract", "multiply", "divide", "round_off", "power_exponent"
    Arguments:
    a : First Number
    b : Second Number (for "round_off": the number of decimal places to round a to)
    operation : the operation to perform
    Returns:
    the numeric result, or an error message string for an undefined operation
    or a division by zero
    """
    op = operation.lower()
    if op == "add":
        return a + b
    if op == "subtract":
        return a - b
    if op == "multiply":
        return a * b
    if op == "divide":
        # Report the problem back to the agent instead of raising
        # ZeroDivisionError, matching the error-string style used below.
        if b == 0:
            return "Cannot divide by zero. Please provide a non-zero second number."
        return a / b
    if op == "round_off":
        # Round `a` to `b` decimal places. The previous int(a/b) was truncating
        # division, not rounding: round_off(3.14159, 2) gave 1 instead of 3.14.
        return round(a, int(b))
    if op == "power_exponent":
        return math.pow(a, b)
    return "Undefined operation. Please Choose from 'add', 'subtract', 'multiply', 'divide', 'round_off', 'power_exponent'"
# SEARCH TOOLS
@tool
def Web_Search(query:str) -> dict:
    """Search the web for a query and return the top results.
    Arguments:
    query : what you want to search for on the internet
    Returns:
    a dict whose "web_results" key holds the Tavily search response (at most 3 results)
    """
    # A fresh client per call; authenticates with the module-level Tavily key.
    search_client = TavilySearch(max_results=3, tavily_api_key=tavily_api_key)
    return {"web_results": search_client.invoke({"query": query})}
@tool
def Arxiv_Search(research_paper_name:str) -> str:
    """Search the arXiv database and return the matching results as text.
    Arguments:
    research_paper_name (str) : name of the research paper to find
    """
    return ArxivQueryRun(api_wrapper=ArxivAPIWrapper()).invoke(research_paper_name)
@tool
def Wikipedia_Search(query:str)->dict:
    """Search Wikipedia and return the matching pages.
    Arguments:
    query (str) : what to search wikipedia for
    Returns:
    a dict whose "wiki_response" key holds a list with one text chunk per page found
    """
    raw_text = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()).run(query)
    # Split the combined text on the "Page: " marker. The first chunk (whatever
    # precedes the first marker) is dropped. If the marker never appears, the
    # list is empty.
    page_chunks = raw_text.split("Page: ")
    return {"wiki_response": page_chunks[1:]}
# YT RELATED TOOLS
def _extract_youtube_video_id(url):
    """Return the video id embedded in a YouTube URL, or None if none is found.

    Supports ``youtube.com/watch?v=<id>``, ``youtu.be/<id>`` and
    ``youtube.com/{shorts,embed,live,v}/<id>`` links, with or without a scheme.
    """
    url = url.strip()
    if "://" not in url:
        # Without a scheme, urlparse treats the host as part of the path.
        url = "https://" + url
    parsed = urlparse(url)
    query_ids = parse_qs(parsed.query).get("v")
    if query_ids:
        return query_ids[0]
    segments = [segment for segment in parsed.path.split("/") if segment]
    if parsed.netloc.lower().endswith("youtu.be") and segments:
        return segments[0]
    if len(segments) >= 2 and segments[0] in ("shorts", "embed", "live", "v"):
        return segments[1]
    return None


@tool
def get_yt_video_transcript(url:str) -> str:
    """
    Get transcript for a youtube video
    if this tool returns -> {"exception": "Video transcript not available, try another way!"},
    then think of another way to analyze the given youtube video
    Arguments:
    url : str - url of the youtube video (watch?v=, youtu.be/, /shorts/ and /embed/ links are accepted)
    Returns:
    video transcript as a string, or the exception dict above when no transcript can be fetched
    """
    unavailable = {"exception": "Video transcript not available, try another way!"}
    video_id = _extract_youtube_video_id(url)
    if video_id is None:
        print(f"could not extract a youtube video id from url : {url}")
        return unavailable
    try:
        transcript = YouTubeTranscriptApi().fetch(video_id=video_id)
    except TranscriptsDisabled:
        print("Subtitles are disabled for this video")
        return unavailable
    except Exception as e:
        # Previously this branch fell through to the join below with
        # `transcript` unbound, raising UnboundLocalError.
        print(f"some error occured while getting yt video transcript : {e}")
        return unavailable
    return " ".join(entry.text for entry in transcript)
@tool
def get_yt_video_info_metadata(url:str) -> dict:
    """
    Get metadata for any youtube video (nothing is downloaded)
    Arguments:
    url : str - url of the youtube video
    Returns:
    json data in the following structure:
    {
    "title": <title-of-the-youtube-video>,
    "description": <description-of-the-youtube-video>,
    "uploader": <account-name-from-which-video-got-uploaded>,
    "duration": <duration-of-the-youtube-video-in-seconds>,
    "tags": <list of words tagged to the video by the uploader>
    }
    """
    options = {"quiet": True, "skip_download": True, "extract_flat": True}
    with yt_dlp.YoutubeDL(options) as downloader:
        info = downloader.extract_info(url, download=False)
    # Missing fields come back as None rather than raising.
    wanted_keys = ("title", "description", "uploader", "duration", "tags")
    return {key: info.get(key) for key in wanted_keys}
# Document Loaders
@tool
def analyze_excel_file(file_name:str) -> str:
    """
    Analyze an excel file using python pandas
    Arguments:
    file_name (str): name of the excel file (relative to the working directory)
    Returns:
    {"basic_sheet_info": <row/column counts, column names and summary statistics>,
     "sheet_content": <the full sheet rendered as a text table>}
    or an error message string if the file cannot be loaded
    """
    try:
        # Path.cwd() is what the old Path(__file__).cwd() actually evaluated to;
        # files are resolved against the process working directory.
        file_path = Path.cwd() / file_name
        df = pd.read_excel(file_path)
        result = f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns \n"
        result += f"Columns : {', '.join(map(str, df.columns))}\n\n"
        result += "Summary Statistics:\n"
        result += str(df.describe())
        # Render the sheet as text. A raw DataFrame is not JSON-serialisable, so it
        # presumably reaches the agent via its repr, which pandas truncates for
        # long sheets (display.max_rows). to_string() keeps every row.
        return {"basic_sheet_info": result, "sheet_content": df.to_string()}
    except Exception as e:
        return f"An Exception occured : {e}"
@tool
def read_file(file_name:str) -> str:
    """
    Read the full contents of a text file, including source-code files.
    Arguments:
    file_name (str): name of the text file (relative to the working directory)
    Returns:
    the file's text, or an error message string if it cannot be read
    """
    try:
        # Path.cwd() is what Path(__file__).cwd() resolves to: the process working directory.
        target = Path.cwd() / file_name
        documents = TextLoader(target, autodetect_encoding=True).load()
        return documents[0].page_content
    except Exception as e:
        return f"An Exception occured : {e}"
@tool
def analyze_image(image_name:str, query:str)->str:
    """
    give the query and image path to this tool for you to get answers related to provided image
    Arguments
    image_name (str): path of image file (relative to the working directory)
    query (str): query related to the image
    Returns:
    the model's answer as text, or an error message string on failure
    """
    import mimetypes
    try:
        import google.generativeai as genai
        genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
        model = genai.GenerativeModel(model_name="gemini-2.5-pro", generation_config=genai.GenerationConfig(temperature=0))
        image_path = Path.cwd() / image_name
        # generate_content cannot turn a pathlib.Path into a content part (the
        # previous code always failed here). Send the raw bytes as an inline blob.
        mime_type, _ = mimetypes.guess_type(image_path.name)
        image_part = {
            # Fallback for extensions mimetypes does not recognise.
            "mime_type": mime_type or "image/png",
            "data": image_path.read_bytes(),
        }
        response = model.generate_content([query, image_part])
        return response.text
    except Exception as e:
        return f"An Exception occured : {e}"
@tool
def transcribe_audio_file(audio_file_name:str)->str:
    """
    Get the transcription of any audio file
    Arguments:
    audio_file_name (str): name of the audio file to be transcribed (relative to the working directory)
    Returns:
    the text produced by Whisper for the audio
    """
    groq_client = Groq(api_key=os.getenv("LLM_API_KEY", None))
    audio_path = Path.cwd() / audio_file_name
    with open(audio_path, "rb") as audio_stream:
        audio_bytes = audio_stream.read()
    # NOTE(review): this calls the *translations* endpoint, which (per the
    # Groq/OpenAI audio API) renders speech as English text. If verbatim
    # transcription of non-English audio is wanted, audio.transcriptions.create
    # is presumably the intended call. Confirm.
    response = groq_client.audio.translations.create(
        file=(str(audio_path), audio_bytes),
        model="whisper-large-v3",
        response_format="json",
        temperature=0,
    )
    return response.text
# Save files
@tool
def save_file_temp(file_extenstion:str, file_name:str, file_data:str) -> str:
    """
    Save text data as a file in the current working directory, on a temporary basis.
    Arguments:
    file_extenstion (str): extension (type) of the file, with or without the leading dot
    file_name (str): name to save the file with, without extension
    file_data (str): text data to write in the file
    Returns:
    {"save_file_name": <name of the saved file>} on success, otherwise an error message
    """
    if not isinstance(file_data, str):
        return "file data not of string format, cannot save. Retry with string type file_data"
    valid_extenions = [".py", ".js", ".txt", ".csv", ".c", ".cpp", ".java"]

    file_extenstion = file_extenstion.strip()
    # Drop only a trailing extension the caller may have included
    # ("script.py" -> "script"). Unlike split(".")[0], splitext keeps inner dots
    # ("data.v2.py" -> "data.v2") and leading-dot names (".env", "./out") intact.
    file_name = os.path.splitext(file_name.strip())[0]
    if not file_extenstion.startswith("."):
        file_extenstion = "." + file_extenstion
    if file_extenstion.lower() not in valid_extenions:
        return f"❌ unsupported file extension provided, supported extensions are: {' , '.join(valid_extenions)}"

    save_name = file_name + file_extenstion
    # Write explicit UTF-8 so non-ASCII content does not depend on the platform's
    # default locale encoding.
    with open(save_name, "w", encoding="utf-8") as f:
        f.write(file_data)
    return {"save_file_name": save_name}
# Code Executors
@tool
def execute_code_file(code: str, language: str) -> str:
    """
    Executes the given code and returns its output.
    Use print() to show values: only printed output is captured.
    Arguments:
    code: The source code to execute (np, pd and plt are pre-imported)
    language: The programming language of the code (currently supports "python")
    Returns:
    {"status": "success" | "error", "result": None, "stdout": <printed output>,
     "stderr": <error output and traceback, if any>}
    """
    supported_languages = ["python"]
    if language.lower() != "python":
        return f"❌ unsupported language code provided, supported languages are: {' , '.join(supported_languages)}"

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt

    output_buffer = io.StringIO()
    output_err = io.StringIO()
    # SECURITY: exec() runs model-generated code in this process with full
    # builtins. There is no sandboxing, so only use this in a trusted environment.
    global_imports = {
        "__builtins__": __builtins__,
        # Without this, `__name__` inside the exec'd code resolves to
        # builtins.__name__ ("builtins"), so a usual
        # `if __name__ == "__main__":` guard would silently skip the script body.
        "__name__": "__main__",
        "np": np,
        "pd": pd,
        "plt": plt,
    }
    result = {
        "status": "error",
        "result": None,
        "stdout": "",
        "stderr": "",
    }
    error_text = ""
    try:
        with contextlib.redirect_stdout(output_buffer), contextlib.redirect_stderr(output_err):
            # exec() always returns None, so "result" stays None.
            exec_result = exec(code, global_imports)
        result["status"] = "success"
        result["result"] = exec_result
    except SystemExit as exit_signal:
        # sys.exit()/exit() in the snippet must not terminate the host process.
        # A zero or empty exit code counts as success.
        if exit_signal.code in (None, 0):
            result["status"] = "success"
        else:
            error_text = f"\nSystemExit: {exit_signal.code}"
    except Exception:
        error_text = f"\n{traceback.format_exc()}"
    # Report captured output in every case. Prints made before a failure, and
    # warnings written to stderr on success, are no longer dropped.
    result["stdout"] = output_buffer.getvalue()
    result["stderr"] = output_err.getvalue() + error_text
    return result