# learn-agent-unit4 / tools.py
# Author: Vindemia
# Commit 2408f2c: "improve prompt and tooling"
import os
import whisper
import mimetypes
import json
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.tools import BraveSearch
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader
from langchain.tools import Tool, tool
from youtube_transcript_api import YouTubeTranscriptApi
from pytube import extract
from pydantic import BaseModel, Field
from langchain_experimental.utilities import PythonREPL
@tool
def get_youtube_transcript(page_url: str) -> str:
    """Return the plain-text transcript of a YouTube video.

    Args:
        page_url (str): YouTube URL of the video

    Returns:
        The transcript snippets joined with newlines, or a message starting
        with "get_youtube_transcript failed:" if any step raises.
    """
    try:
        vid = extract.video_id(page_url)
        fetched = YouTubeTranscriptApi().fetch(vid)
        # Drop timing information; only the spoken text is kept.
        lines = (snippet.text for snippet in fetched.snippets)
        return "\n".join(lines)
    except Exception as e:
        return f"get_youtube_transcript failed: {e}"
@tool
def multiply(a: float, b: float) -> float:
    """Return the product of two numbers.

    Args:
        a (float): the first factor
        b (float): the second factor
    """
    product = a * b
    return product
@tool
def add(a: float, b: float) -> float:
    """Return the sum of two numbers.

    Args:
        a (float): the first addend
        b (float): the second addend
    """
    total = a + b
    return total
@tool
def subtract(a: float, b: float) -> float:
    """Subtracts the second number from the first.

    Args:
        a (float): the number to subtract from
        b (float): the number to subtract

    Returns:
        The difference a - b (a float, like the other arithmetic tools).
    """
    return a - b
@tool
def divide(a: float, b: float) -> float:
    """Divides the first number by the second.

    Args:
        a (float): the dividend
        b (float): the divisor; must be non-zero

    Returns:
        The quotient a / b.

    Raises:
        ValueError: if b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of dividing a by b, using Python's % operator.

    Args:
        a (int): the dividend
        b (int): the divisor
    """
    remainder = a % b
    return remainder
@tool
def power(a: float, b: float) -> float:
    """Raise a to the power of b.

    Args:
        a (float): the base
        b (float): the exponent
    """
    result = pow(a, b)
    return result
@tool
def get_web_search_result(query: str) -> str:
    """Searches the web with DuckDuckGo based on a query.

    Args:
        query: The search query.

    Returns:
        The search results produced by DuckDuckGoSearchRun for the query.
    """
    # Trace which tool the agent invoked.
    print("get_web_search_result")
    search_tool = DuckDuckGoSearchRun()
    return search_tool.invoke(query)
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return at most 5 matching pages.

    Args:
        query: The search query.

    Returns:
        A dict with the single key "wiki_results", whose value is a string of
        <Document source=... page=...> blocks separated by "---" lines.
    """
    # Trace which tool the agent invoked.
    print("wiki_search")
    search_docs = WikipediaLoader(query=query, load_max_docs=5).load()
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", "")}" '
        f'page="{doc.metadata.get("page", "")}"/>\n'
        f"{doc.page_content}\n</Document>"
        for doc in search_docs
    )
    return {"wiki_results": formatted_search_docs}
@tool
def arvix_search(query: str) -> dict:
    """Search arXiv for a query and return at most 3 matching papers.

    Args:
        query: The search query.

    Returns:
        A dict with the single key "arvix_results", whose value is a string of
        <Document source=... page=...> blocks (content truncated to the first
        1000 characters) separated by "---" lines.
    """
    # Trace which tool the agent invoked.
    print("arvix_search")
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    entries = []
    for doc in search_docs:
        meta = doc.metadata
        # ArxivLoader's default metadata has no "source" key (it reports
        # "Title", "Published", "Authors", "Summary"), so indexing it directly
        # raised KeyError. Fall back to the entry id or the title.
        source = meta.get("source") or meta.get("entry_id") or meta.get("Title", "")
        entries.append(
            f'<Document source="{source}" page="{meta.get("page", "")}"/>\n'
            f"{doc.page_content[:1000]}\n</Document>"
        )
    return {"arvix_results": "\n\n---\n\n".join(entries)}
# Whisper models are expensive to load, so keep one instance per model name
# for the lifetime of the process instead of reloading it on every call.
_WHISPER_MODELS = {}


def _get_whisper_model(name="small"):
    """Return the Whisper model called *name*, loading it only on first use."""
    model = _WHISPER_MODELS.get(name)
    if model is None:
        model = whisper.load_model(name)
        _WHISPER_MODELS[name] = model
    return model


@tool
def transcribe_audio(file_path: str) -> dict:
    """
    Transcribes an audio file to text using a local Whisper model.

    Args:
        file_path: Path to the audio file

    Returns:
        A dictionary with "status". On success it also has "transcription",
        "language" and "file_path". On failure it has "message" describing
        the error.
    """
    try:
        print(f"Transcribing audio file: {file_path}")
        if not os.path.exists(file_path):
            return {
                "status": "error",
                "message": f"File not found: {file_path}"
            }
        # "small" trades some accuracy for speed.
        # Other options include: tiny, base, medium, large.
        model = _get_whisper_model("small")
        result = model.transcribe(file_path)
        response = {
            "status": "success",
            "transcription": result["text"],
            "language": result.get("language", "unknown"),
            "file_path": file_path
        }
    except Exception as e:
        # Best-effort tool: report the failure to the agent instead of raising.
        response = {
            "status": "error",
            "message": f"Error transcribing audio: {e}"
        }
    print(response)
    return response
# Argument schema passed as `args_schema` to python_repl_tool below.
# It has a single required string field, `code`, which holds the source to run.
class PythonREPLInput(BaseModel):
    code: str = Field(description="The Python code string to execute.")
# One PythonREPL instance backs every invocation of the tool below.
python_repl = PythonREPL()

# Tool description shown to the model. NOTE(review): the tool runs
# model-written Python through PythonREPL. This text is the only guard visible
# here against harmful code, so confirm whether the host is sandboxed.
_PYTHON_REPL_DESCRIPTION = """A Python REPL shell (Read-Eval-Print Loop).
Use this to execute single or multi-line python commands.
Input should be syntactically valid Python code.
Always end your code with `print(...)` to see the output.
Do NOT execute code that could be harmful to the host system.
You are allowed to download files from URLs.
Do not use this tool as a web search.
Do NOT send commands that block indefinitely (e.g., `input()`)."""

python_repl_tool = Tool(
    name="python_repl",
    func=python_repl.run,
    description=_PYTHON_REPL_DESCRIPTION,
    args_schema=PythonREPLInput,
)
# Tools exposed to the agent, in registration order.
available_tools = [
    wiki_search,
    arvix_search,
    transcribe_audio,
    python_repl_tool,
    multiply,
    add,
    subtract,
    divide,
    modulus,
    power,
    get_youtube_transcript,
]

# Web search goes last. Prefer Brave when an API key is configured.
# BraveSearch expects a str API key, and os.getenv returns None when the
# variable is unset, which made this module fail at import. Without a key,
# fall back to the keyless DuckDuckGo tool instead.
_brave_api_key = os.getenv("BRAVE_SEARCH_API_KEY")
if _brave_api_key:
    available_tools.append(
        BraveSearch.from_api_key(
            api_key=_brave_api_key,
            search_kwargs={"count": 5},
        )
    )
else:
    available_tools.append(get_web_search_result)