FinalSubmission / gaiaX /tools.py
derkaal's picture
Upload folder using huggingface_hub
c7eca3d verified
#!/usr/bin/env python3
"""
LangChain tools module for GAIA Benchmark Agent.
This module defines the custom tools used by the LangChain agent
to interact with the GAIA benchmark API and process questions,
as well as tools for external information sources (search engines and
YouTube) and for transcribing audio context files.
"""
import json
import tempfile
import re
import os
from typing import Dict, Any, List, Optional
from pathlib import Path
from langchain.tools import BaseTool, tool
from gaiaX.config import logger, API_BASE_URL, SERPAPI_API_KEY, YOUTUBE_API_KEY, TAVILY_API_KEY, WHISPER_API_KEY
from gaiaX.api import download_file_for_task, get_question_details
@tool
def fetch_question_details(task_id: str, api_base_url: str = API_BASE_URL) -> Dict[str, Any]:
    """
    Get detailed information about a specific question/task.
    Args:
        task_id: The ID of the task to get details for
        api_base_url: Base URL for the GAIA API
    Returns:
        Dictionary containing question details
    """
    # NOTE: the docstring above is kept verbatim on purpose -- the @tool
    # decorator exposes it to the agent as this tool's description.
    # The function itself is a thin pass-through to the API helper.
    question_details = get_question_details(task_id, api_base_url)
    return question_details
@tool
def fetch_context_file(task_id: str, api_base_url: str = API_BASE_URL) -> str:
    """
    Download and read the context file for a specific task.

    Text files are returned as their UTF-8 decoded contents. Audio files
    (and binary files that might be audio) are left on disk and the returned
    message contains their path so it can be passed to the transcribe_audio
    tool.

    Args:
        task_id: The ID of the task to download the file for
        api_base_url: Base URL for the GAIA API
    Returns:
        String containing the file contents or error message
    """
    import shutil  # local import: only needed to clean up the download directory

    audio_extensions = ['.mp3', '.wav', '.m4a', '.flac', '.aac', '.ogg']

    temp_dir = None
    # Set to True whenever the returned message hands the file path to the
    # caller; in that case the download must outlive this function. A
    # self-deleting TemporaryDirectory would remove the file before
    # transcribe_audio could ever open it.
    keep_download = False
    try:
        # Create a temporary directory to store the file
        temp_dir = tempfile.mkdtemp()
        file_path = download_file_for_task(api_base_url, task_id, temp_dir)
        file_ext = Path(file_path).suffix.lower()

        # Check if it's an audio file
        if file_ext in audio_extensions:
            logger.info(f"Audio file detected ({file_ext}). Attempting transcription.")
            keep_download = True
            return f"Audio file detected ({file_ext}). Use the transcribe_audio tool with the path: {file_path}"

        # Try to read the file as text
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        except UnicodeDecodeError:
            # Not decodable as UTF-8 text: report size/extension instead.
            file_size = Path(file_path).stat().st_size
            file_ext = Path(file_path).suffix
            # Check if it might be an audio file with wrong extension
            if 1024 < file_size < 100 * 1024 * 1024:  # Between 1KB and 100MB
                keep_download = True
                return f"Binary file detected ({file_ext}, {file_size} bytes). This might be an audio file. Try using the transcribe_audio tool with the path: {file_path}"
            return f"Binary file detected ({file_ext}, {file_size} bytes). This file cannot be displayed as text. Please use specialized tools to analyze this type of file."
    except Exception as e:
        logger.error(f"Error fetching context file: {str(e)}")
        return f"Error fetching context file: {str(e)}"
    finally:
        # Remove the download unless its path was returned to the caller.
        # Kept downloads stay in the system temp location (not cleaned here).
        if temp_dir is not None and not keep_download:
            shutil.rmtree(temp_dir, ignore_errors=True)
# Define a class for each tool to make them more configurable
class QuestionDetailsTool(BaseTool):
    """Tool for fetching question details from the GAIA API.

    Class-based variant that delegates to ``get_question_details``.
    """

    # The overrides carry explicit type annotations: BaseTool is a pydantic
    # model, and pydantic-v2-based LangChain releases reject a base-class
    # field that is overridden by a non-annotated attribute.
    name: str = "get_question_details"
    description: str = "Get detailed information about a specific question/task"

    def _run(self, task_id: str, api_base_url: str = API_BASE_URL) -> Dict[str, Any]:
        """Execute the tool.

        Args:
            task_id: The ID of the task to get details for
            api_base_url: Base URL for the GAIA API
        Returns:
            Whatever ``get_question_details`` returns for the task
        """
        return get_question_details(task_id, api_base_url)

    def _arun(self, task_id: str, api_base_url: str = API_BASE_URL):
        """Execute the tool asynchronously.

        Raises:
            NotImplementedError: always; only the synchronous path exists.
        """
        raise NotImplementedError("Async version not implemented")
class ContextFileTool(BaseTool):
    """Tool for fetching and reading context files for tasks.

    Class-based variant that delegates to the ``fetch_context_file`` tool.
    """

    # The overrides carry explicit type annotations: BaseTool is a pydantic
    # model, and pydantic-v2-based LangChain releases reject a base-class
    # field that is overridden by a non-annotated attribute.
    name: str = "fetch_context_file"
    description: str = "Download and read the context file for a specific task"

    def _run(self, task_id: str, api_base_url: str = API_BASE_URL) -> str:
        """Execute the tool.

        Args:
            task_id: The ID of the task to download the file for
            api_base_url: Base URL for the GAIA API
        Returns:
            String containing the file contents or an error/info message
        """
        # fetch_context_file is decorated with @tool, so the module-level name
        # is a tool object rather than the plain function. Calling it with two
        # positional arguments would not bind api_base_url as a tool argument,
        # so the arguments are passed by name through the tool's invoke().
        return fetch_context_file.invoke(
            {"task_id": task_id, "api_base_url": api_base_url}
        )

    def _arun(self, task_id: str, api_base_url: str = API_BASE_URL):
        """Execute the tool asynchronously.

        Raises:
            NotImplementedError: always; only the synchronous path exists.
        """
        raise NotImplementedError("Async version not implemented")
@tool
def search_youtube(query: str, max_results: int = 3, api_key: str = YOUTUBE_API_KEY) -> str:
    """
    Search for YouTube videos related to a query and return information about them.

    Args:
        query: The search query
        max_results: Maximum number of results to return (default: 3)
        api_key: YouTube API key
    Returns:
        String containing information about the videos (title, URL, channel,
        publish date, duration, views, likes and a shortened description),
        or a message explaining why no information could be returned
    """
    if not api_key:
        return "YouTube API key is not available. Cannot search YouTube."
    try:
        from googleapiclient.discovery import build

        # Initialize the YouTube API client
        youtube = build('youtube', 'v3', developerKey=api_key)

        # Execute the search request
        search_response = youtube.search().list(
            q=query,
            part='id,snippet',
            maxResults=max_results,
            type='video'
        ).execute()
        items = search_response.get('items', [])
        if not items:
            # Give the agent an explicit signal instead of an empty string.
            return f"No YouTube videos found for query: {query}"

        # Fetch duration/statistics for all hits in ONE request (the id
        # parameter accepts a comma-separated list) instead of one request
        # per video.
        video_ids = [item['id']['videoId'] for item in items]
        details_response = youtube.videos().list(
            part='contentDetails,statistics',
            id=','.join(video_ids)
        ).execute()
        details_by_id = {
            details['id']: details
            for details in details_response.get('items', [])
        }

        # Format the results as a string, keeping the search ranking order
        formatted_blocks = []
        for position, item in enumerate(items, 1):
            video_id = item['id']['videoId']
            snippet = item['snippet']
            # A video can be missing from the details response; fall back to
            # 'N/A' rather than failing the whole search.
            details = details_by_id.get(video_id, {})
            duration = details.get('contentDetails', {}).get('duration', 'N/A')
            statistics = details.get('statistics', {})
            view_count = statistics.get('viewCount', 'N/A')
            like_count = statistics.get('likeCount', 'N/A')
            video_url = f"https://www.youtube.com/watch?v={video_id}"
            formatted_blocks.append(
                f"Video {position}:\n"
                f"Title: {snippet['title']}\n"
                f"URL: {video_url}\n"
                f"Channel: {snippet['channelTitle']}\n"
                f"Published: {snippet['publishedAt']}\n"
                f"Duration: {duration}\n"
                f"Views: {view_count}\n"
                f"Likes: {like_count}\n"
                f"Description: {snippet['description'][:200]}...\n\n"
            )
        return "".join(formatted_blocks)
    except ImportError:
        return "Required packages not installed. Please install googleapiclient with: pip install google-api-python-client"
    except Exception as e:
        logger.error(f"Error searching YouTube: {str(e)}")
        return f"Error searching YouTube: {str(e)}"
def _extract_youtube_video_id(video_url: str) -> Optional[str]:
    """Return the 11-character video ID found in *video_url*, or None."""
    candidate = video_url.strip()
    # Accept a bare video ID as well as a full URL.
    if re.fullmatch(r'[0-9A-Za-z_-]{11}', candidate):
        return candidate
    # Ordered from most to least specific. The last pattern is the original
    # loose rule, kept as a fallback so every URL that used to be accepted
    # still is; trying the specific forms first stops an earlier 11+ character
    # path segment or hostname from being mistaken for the ID.
    id_patterns = (
        r'[?&]v=([0-9A-Za-z_-]{11})',
        r'(?:youtu\.be/|/embed/|/shorts/|/live/|/v/)([0-9A-Za-z_-]{11})',
        r'(?:v=|\/)([0-9A-Za-z_-]{11})',
    )
    for pattern in id_patterns:
        match = re.search(pattern, candidate)
        if match:
            return match.group(1)
    return None


@tool
def get_youtube_transcript(video_url: str) -> str:
    """
    Get the transcript of a YouTube video.

    Args:
        video_url: URL of the YouTube video (a bare 11-character video ID is
            accepted as well)
    Returns:
        String containing the transcript, one "[MM:SS] text" line per entry,
        or an error message
    """
    try:
        from youtube_transcript_api import YouTubeTranscriptApi

        # Extract video ID from URL
        video_id = _extract_youtube_video_id(video_url)
        if not video_id:
            return f"Invalid YouTube URL: {video_url}"

        # Get the transcript
        transcript_list = YouTubeTranscriptApi.get_transcript(video_id)

        # Format the transcript: prefix every entry with its start time
        lines = []
        for entry in transcript_list:
            minutes, seconds = divmod(int(entry['start']), 60)
            lines.append(f"[{minutes:02d}:{seconds:02d}] {entry['text']}\n")
        return "".join(lines)
    except ImportError:
        return "Required packages not installed. Please install youtube-transcript-api with: pip install youtube-transcript-api"
    except Exception as e:
        logger.error(f"Error getting YouTube transcript: {str(e)}")
        return f"Error getting YouTube transcript: {str(e)}"
@tool
def transcribe_audio(file_path: str, api_key: str = WHISPER_API_KEY) -> str:
    """
    Transcribe audio file to text using OpenAI's Whisper API with Google Speech Recognition fallback.

    Non-WAV input is first converted to a temporary WAV file, which is
    removed again before the function returns.

    Args:
        file_path: Path to the audio file
        api_key: OpenAI API key for Whisper
    Returns:
        String containing the transcribed text, or an error message
    """
    temp_wav_path = None
    try:
        import speech_recognition as sr
        from pydub import AudioSegment

        # Check if file exists
        if not os.path.exists(file_path):
            return f"Error: File not found at {file_path}"

        # Get file extension
        file_ext = os.path.splitext(file_path)[1].lower()

        # Convert audio file to WAV format if needed
        if file_ext != '.wav':
            try:
                logger.info(f"Converting {file_ext} file to WAV format")
                # Unique file in the system temp directory: avoids clashes
                # between calls and does not require the input's directory
                # to be writable.
                fd, temp_wav_path = tempfile.mkstemp(prefix="temp_audio_", suffix=".wav")
                os.close(fd)
                audio = AudioSegment.from_file(file_path)
                audio.export(temp_wav_path, format="wav")
                file_path = temp_wav_path
                logger.info(f"Converted audio saved to {temp_wav_path}")
            except Exception as e:
                logger.error(f"Error converting audio: {str(e)}")
                return f"Error converting audio: {str(e)}"

        # Try OpenAI Whisper API first
        if api_key:
            try:
                logger.info("Attempting transcription with OpenAI Whisper API")
                import openai
                client = openai.OpenAI(api_key=api_key)
                with open(file_path, "rb") as audio_file:
                    transcript = client.audio.transcriptions.create(
                        model="whisper-1",
                        file=audio_file
                    )
                return transcript.text
            except Exception as e:
                logger.error(f"Error with Whisper API: {str(e)}")
                logger.info("Falling back to Google Speech Recognition")

        # Fallback to Google Speech Recognition. The audio is only decoded
        # here, so a file the recognizer cannot load no longer prevents the
        # Whisper attempt above.
        recognizer = sr.Recognizer()
        with sr.AudioFile(file_path) as source:
            audio_data = recognizer.record(source)
        try:
            logger.info("Using Google Speech Recognition")
            return recognizer.recognize_google(audio_data)
        except sr.UnknownValueError:
            return "Google Speech Recognition could not understand the audio"
        except sr.RequestError as e:
            return f"Could not request results from Google Speech Recognition service: {str(e)}"
    except ImportError:
        return "Required packages not installed. Please install pydub and SpeechRecognition with: pip install pydub SpeechRecognition"
    except Exception as e:
        logger.error(f"Error transcribing audio: {str(e)}")
        return f"Error transcribing audio: {str(e)}"
    finally:
        # Clean up the temporary WAV on every exit path, including failures.
        if temp_wav_path and os.path.exists(temp_wav_path):
            try:
                os.remove(temp_wav_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {temp_wav_path}: {cleanup_error}")
@tool
def extract_ingredients_from_audio(file_path: str, api_key: str = WHISPER_API_KEY) -> str:
    """
    Extract ingredients list from a recipe audio file.

    The audio is transcribed with the transcribe_audio tool, then the
    ingredients are located heuristically: first by looking for an
    "ingredients" / "you will need" section, otherwise by collecting the text
    around quantity-plus-unit mentions.

    Args:
        file_path: Path to the audio file
        api_key: OpenAI API key for Whisper
    Returns:
        String containing the extracted ingredients, or the transcription
        error message if the audio could not be transcribed
    """
    try:
        # First transcribe the audio. transcribe_audio is decorated with
        # @tool, so it is a tool object: arguments are passed by name through
        # invoke() rather than positionally.
        tool_input = {"file_path": file_path}
        if api_key is not None:
            tool_input["api_key"] = api_key
        transcript = transcribe_audio.invoke(tool_input)

        # Every failure message transcribe_audio can return starts with one
        # of these prefixes; pass such messages through unchanged.
        failure_prefixes = (
            "Error",
            "Could not",
            "Google Speech Recognition could not",
            "Required packages not installed",
        )
        if transcript.startswith(failure_prefixes):
            return transcript

        # Extract ingredients using pattern matching
        ingredients_section = None
        # Common patterns that indicate ingredients sections in recipes
        patterns = [
            r"(?i)ingredients[:\s]+(.+?)(?:instructions|directions|method|steps|preparation|$)",
            r"(?i)you(?:'ll| will) need[:\s]+(.+?)(?:instructions|directions|method|steps|preparation|$)",
            r"(?i)what you(?:'ll| will) need[:\s]+(.+?)(?:instructions|directions|method|steps|preparation|$)",
            r"(?i)here's what you(?:'ll| will) need[:\s]+(.+?)(?:instructions|directions|method|steps|preparation|$)"
        ]
        for pattern in patterns:
            match = re.search(pattern, transcript, re.DOTALL)
            if match:
                ingredients_section = match.group(1).strip()
                break

        if not ingredients_section:
            # If no clear ingredients section, try to extract using common
            # ingredient patterns. The trailing \b keeps short units such as
            # "g" or "l" from matching the start of ordinary words
            # ("2 large", "3 green").
            measurements = (
                r"(?i)(\d+(?:\s+\d+/\d+)?|\d+/\d+)\s*"
                r"(?:cups?|tablespoons?|tbsp|teaspoons?|tsp|ounces?|oz|pounds?|lbs?"
                r"|grams?|g|kg|ml|l|pinch(?:es)?|dash(?:es)?|handfuls?|cloves?"
                r"|bunch(?:es)?|cans?|packages?|pkg|bottles?)\b"
            )
            # Keep up to 50 characters on either side of each measurement
            potential_ingredients = [
                transcript[max(0, m.start() - 50):min(len(transcript), m.end() + 50)]
                for m in re.finditer(measurements, transcript)
            ]
            if potential_ingredients:
                ingredients_section = "\n".join(potential_ingredients)
            else:
                return "Could not identify ingredients section in the audio. Please provide a clearer recording or manually list the ingredients."

        # Format the ingredients as a list, dropping lines that look like
        # preparation steps rather than ingredients
        formatted_ingredients = []
        for line in ingredients_section.split("\n"):
            line = line.strip()
            if line and not re.search(r"(?i)(instruction|direction|method|step|preparation|preheat|mix|stir|cook|bake)", line):
                formatted_ingredients.append(f"- {line}")

        if not formatted_ingredients:
            # If no clear ingredient lines, just return the whole section
            return f"Extracted Ingredients:\n{ingredients_section}"
        return "Extracted Ingredients:\n" + "\n".join(formatted_ingredients)
    except Exception as e:
        logger.error(f"Error extracting ingredients: {str(e)}")
        return f"Error extracting ingredients: {str(e)}"
# Function to get all available tools
def get_tools(include_search: bool = True, tavily_api_key: str = None,
              serpapi_api_key: str = None, youtube_api_key: str = None,
              whisper_api_key: str = None):
    """
    Assemble the list of tools handed to the agent.

    The GAIA API tools and the audio tools are always included. The YouTube
    tools are added only when a YouTube API key is given, and each web search
    tool only when searching is enabled and its API key is given. A search
    tool that cannot be imported or constructed is skipped with a warning.

    Args:
        include_search: Whether to include search tools
        tavily_api_key: Tavily API key for search functionality
        serpapi_api_key: SerpAPI key for search functionality
        youtube_api_key: YouTube API key for video content access
        whisper_api_key: OpenAI Whisper API key for audio transcription
            (accepted but not referenced in this function's body)
    Returns:
        List of tools
    """
    # GAIA API tools followed by the always-available audio tools
    agent_tools = [
        fetch_question_details,
        fetch_context_file,
        transcribe_audio,
        extract_ingredients_from_audio,
    ]
    logger.info("Audio processing tools added to agent tools")

    # YouTube tools are gated on the API key
    if youtube_api_key:
        agent_tools.extend([search_youtube, get_youtube_transcript])
        logger.info("YouTube tools added to agent tools")

    # Everything below is search-related
    if not include_search:
        return agent_tools

    # Tavily search, if an API key is available
    if tavily_api_key:
        try:
            from langchain_community.tools.tavily_search import TavilySearchResults
            agent_tools.append(
                TavilySearchResults(
                    max_results=7,  # Increased from 3 to get more comprehensive results
                    api_key=tavily_api_key
                )
            )
            logger.info("Tavily search tool added to agent tools")
        except ImportError:
            logger.warning("Could not import TavilySearchResults. Tavily search will be disabled.")
        except Exception as exc:
            logger.warning(f"Error initializing Tavily search tool: {exc}")

    # SerpAPI search, if an API key is available
    if serpapi_api_key:
        try:
            from langchain_community.utilities.serpapi import SerpAPIWrapper
            from langchain.tools import Tool
            serp_wrapper = SerpAPIWrapper(serpapi_api_key=serpapi_api_key)
            agent_tools.append(
                Tool(
                    name="SerpAPI Search",
                    description="A search engine. Useful for when you need to answer questions about current events or the current state of the world. Input should be a search query.",
                    func=serp_wrapper.run
                )
            )
            logger.info("SerpAPI search tool added to agent tools")
        except ImportError:
            logger.warning("Could not import SerpAPIWrapper. SerpAPI search will be disabled.")
        except Exception as exc:
            logger.warning(f"Error initializing SerpAPI search tool: {exc}")

    return agent_tools