|
|
| """
|
| LangChain tools module for GAIA Benchmark Agent.
|
|
|
| This module defines the custom tools used by the LangChain agent
|
| to interact with the GAIA benchmark API and process questions,
|
| as well as external information sources like search engines and YouTube.
|
| """
|
|
|
| import json
|
| import tempfile
|
| import re
|
| import os
|
| from typing import Dict, Any, List, Optional
|
| from pathlib import Path
|
|
|
| from langchain.tools import BaseTool, tool
|
|
|
| from gaiaX.config import logger, API_BASE_URL, SERPAPI_API_KEY, YOUTUBE_API_KEY, TAVILY_API_KEY, WHISPER_API_KEY
|
| from gaiaX.api import download_file_for_task, get_question_details
|
|
|
@tool
def fetch_question_details(task_id: str, api_base_url: str = API_BASE_URL) -> Dict[str, Any]:
    """
    Get detailed information about a specific question/task.

    Args:
        task_id: The ID of the task to get details for
        api_base_url: Base URL for the GAIA API

    Returns:
        Dictionary containing question details
    """
    # Thin agent-facing wrapper: the lookup itself lives in gaiaX.api.
    # The docstring above is left verbatim because the @tool decorator
    # exposes it to the agent as the tool description.
    question_details = get_question_details(task_id, api_base_url)
    return question_details
|
|
|
@tool
def fetch_context_file(task_id: str, api_base_url: str = API_BASE_URL) -> str:
    """
    Download and read the context file for a specific task.

    Args:
        task_id: The ID of the task to download the file for
        api_base_url: Base URL for the GAIA API

    Returns:
        String containing the file contents or error message
    """
    import shutil

    audio_extensions = {'.mp3', '.wav', '.m4a', '.flac', '.aac', '.ogg'}

    download_dir = None
    # Set to True whenever the returned message hands the file path back to
    # the agent; in that case the download must survive this call.
    keep_download = False
    try:
        # mkdtemp instead of TemporaryDirectory: the context-manager form
        # deletes the directory on exit, which invalidated the path that the
        # audio/binary messages below tell the agent to pass on to
        # transcribe_audio.
        download_dir = tempfile.mkdtemp(prefix="gaia_context_")
        file_path = download_file_for_task(api_base_url, task_id, download_dir)
        file_ext = Path(file_path).suffix.lower()

        # Audio is never read as text; point the agent at the transcription tool.
        if file_ext in audio_extensions:
            keep_download = True
            logger.info(f"Audio file detected ({file_ext}). Attempting transcription.")
            return f"Audio file detected ({file_ext}). Use the transcribe_audio tool with the path: {file_path}"

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        except UnicodeDecodeError:
            # Not valid UTF-8 text: describe the file instead of dumping bytes.
            file_size = Path(file_path).stat().st_size
            raw_ext = Path(file_path).suffix  # original case, as reported before

            # Size heuristic: between 1 KiB and 100 MiB might be audio with
            # an unrecognised extension.
            if 1024 < file_size < 100 * 1024 * 1024:
                keep_download = True
                return f"Binary file detected ({raw_ext}, {file_size} bytes). This might be an audio file. Try using the transcribe_audio tool with the path: {file_path}"
            return f"Binary file detected ({raw_ext}, {file_size} bytes). This file cannot be displayed as text. Please use specialized tools to analyze this type of file."
    except Exception as e:
        logger.error(f"Error fetching context file: {str(e)}")
        return f"Error fetching context file: {str(e)}"
    finally:
        # Remove the download unless its path was handed back to the agent.
        # Kept directories are intentionally left for the follow-up tool call.
        if download_dir and not keep_download:
            shutil.rmtree(download_dir, ignore_errors=True)
|
|
|
|
|
class QuestionDetailsTool(BaseTool):
    """Class-based tool that fetches question details from the GAIA API.

    Delegates to ``get_question_details``; only synchronous execution is
    supported.
    """

    # Annotated explicitly: BaseTool is a pydantic model, and pydantic v2
    # rejects overriding an inherited field with a non-annotated attribute.
    # The annotated form is also valid on pydantic v1-based LangChain.
    name: str = "get_question_details"
    description: str = "Get detailed information about a specific question/task"

    def _run(self, task_id: str, api_base_url: str = API_BASE_URL) -> Dict[str, Any]:
        """Fetch the details for ``task_id`` from the API at ``api_base_url``.

        Returns:
            Whatever ``get_question_details`` returns (annotated as a dict).
        """
        return get_question_details(task_id, api_base_url)

    async def _arun(self, task_id: str, api_base_url: str = API_BASE_URL) -> Dict[str, Any]:
        """Async execution hook; deliberately unsupported.

        Declared ``async`` to match the coroutine hook on ``BaseTool``.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Async version not implemented")
|
|
|
class ContextFileTool(BaseTool):
    """Class-based tool that downloads and reads the context file of a task.

    Delegates to the module-level ``fetch_context_file`` tool; only
    synchronous execution is supported.
    """

    # Annotated explicitly: BaseTool is a pydantic model, and pydantic v2
    # rejects overriding an inherited field with a non-annotated attribute.
    name: str = "fetch_context_file"
    description: str = "Download and read the context file for a specific task"

    def _run(self, task_id: str, api_base_url: str = API_BASE_URL) -> str:
        """Return the context file contents (or a status/error message).

        ``fetch_context_file`` is wrapped by ``@tool``, so the module-level
        name is a tool object rather than a plain function. Calling that
        object positionally would not bind ``api_base_url`` to the function
        parameter, so the undecorated function is called via ``.func``.
        The ``getattr`` fallback keeps this working if the decorator is
        ever removed.
        """
        fetch = getattr(fetch_context_file, "func", fetch_context_file)
        return fetch(task_id, api_base_url)

    async def _arun(self, task_id: str, api_base_url: str = API_BASE_URL) -> str:
        """Async execution hook; deliberately unsupported.

        Declared ``async`` to match the coroutine hook on ``BaseTool``.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Async version not implemented")
|
|
|
@tool
def search_youtube(query: str, max_results: int = 3, api_key: str = YOUTUBE_API_KEY) -> str:
    """
    Search for YouTube videos related to a query and return information about them.

    Args:
        query: The search query
        max_results: Maximum number of results to return (default: 3)
        api_key: YouTube API key

    Returns:
        String containing information about the videos
    """
    if not api_key:
        return "YouTube API key is not available. Cannot search YouTube."

    try:
        # Imported lazily so the module still loads without the Google client.
        from googleapiclient.discovery import build

        youtube = build('youtube', 'v3', developerKey=api_key)

        search_response = youtube.search().list(
            q=query,
            part='id,snippet',
            maxResults=max_results,
            type='video'
        ).execute()

        items = search_response.get('items', [])
        if not items:
            # Give the agent an explicit signal instead of an empty string.
            return f"No YouTube videos found for query: {query}"

        # One batched videos.list request for all hits (the `id` parameter
        # accepts a comma-separated list) instead of one request per video.
        video_ids = [item['id']['videoId'] for item in items]
        details_response = youtube.videos().list(
            part='contentDetails,statistics',
            id=','.join(video_ids)
        ).execute()
        details_by_id = {
            detail['id']: detail for detail in details_response.get('items', [])
        }

        sections = []
        for index, item in enumerate(items, 1):
            video_id = item['id']['videoId']
            snippet = item['snippet']

            # A video can be missing from the details response; fall back to
            # 'N/A' rather than failing the whole search.
            detail = details_by_id.get(video_id, {})
            statistics = detail.get('statistics', {})
            duration = detail.get('contentDetails', {}).get('duration', 'N/A')
            view_count = statistics.get('viewCount', 'N/A')
            like_count = statistics.get('likeCount', 'N/A')

            video_url = f"https://www.youtube.com/watch?v={video_id}"
            sections.append(
                f"Video {index}:\n"
                f"Title: {snippet['title']}\n"
                f"URL: {video_url}\n"
                f"Channel: {snippet['channelTitle']}\n"
                f"Published: {snippet['publishedAt']}\n"
                f"Duration: {duration}\n"
                f"Views: {view_count}\n"
                f"Likes: {like_count}\n"
                f"Description: {snippet['description'][:200]}...\n\n"
            )

        return "".join(sections)

    except ImportError:
        return "Required packages not installed. Please install googleapiclient with: pip install google-api-python-client"
    except Exception as e:
        logger.error(f"Error searching YouTube: {str(e)}")
        return f"Error searching YouTube: {str(e)}"
|
|
|
@tool
def get_youtube_transcript(video_url: str) -> str:
    """
    Get the transcript of a YouTube video.

    Args:
        video_url: URL of the YouTube video

    Returns:
        String containing the transcript
    """
    try:
        # Imported lazily so the module still loads without the package.
        from youtube_transcript_api import YouTubeTranscriptApi

        # The 11-character video ID follows either "v=" or a "/" in the URL.
        video_id_match = re.search(r'(?:v=|\/)([0-9A-Za-z_-]{11}).*', video_url)
        if not video_id_match:
            return f"Invalid YouTube URL: {video_url}"

        video_id = video_id_match.group(1)

        # Older releases expose the static get_transcript(); newer releases
        # dropped it in favour of an instance-level fetch(), whose result
        # converts to the same list-of-dicts shape via to_raw_data().
        if hasattr(YouTubeTranscriptApi, "get_transcript"):
            entries = YouTubeTranscriptApi.get_transcript(video_id)
        else:
            entries = YouTubeTranscriptApi().fetch(video_id).to_raw_data()

        lines = []
        for entry in entries:
            start_time = entry['start']
            minutes = int(start_time // 60)
            seconds = int(start_time % 60)
            # Minutes are not wrapped at 60, so long videos show e.g. 75:30.
            lines.append(f"[{minutes:02d}:{seconds:02d}] {entry['text']}\n")

        return "".join(lines)

    except ImportError:
        return "Required packages not installed. Please install youtube-transcript-api with: pip install youtube-transcript-api"
    except Exception as e:
        logger.error(f"Error getting YouTube transcript: {str(e)}")
        return f"Error getting YouTube transcript: {str(e)}"
|
|
|
@tool
def transcribe_audio(file_path: str, api_key: str = WHISPER_API_KEY) -> str:
    """
    Transcribe audio file to text using OpenAI's Whisper API with Google Speech Recognition fallback.

    Args:
        file_path: Path to the audio file
        api_key: OpenAI API key for Whisper

    Returns:
        String containing the transcribed text
    """
    # Path of the converted WAV, if one is created; removed in `finally` so
    # it is cleaned up on every exit path, including errors.
    temp_wav_path = None
    try:
        import speech_recognition as sr
        from pydub import AudioSegment

        if not os.path.exists(file_path):
            return f"Error: File not found at {file_path}"

        file_ext = os.path.splitext(file_path)[1].lower()

        # Everything that is not already WAV is converted first.
        if file_ext != '.wav':
            try:
                logger.info(f"Converting {file_ext} file to WAV format")
                # Unique temp file: a fixed "temp_audio.wav" next to the
                # source collides across concurrent calls and requires the
                # source directory to be writable.
                fd, temp_wav_path = tempfile.mkstemp(prefix="temp_audio_", suffix=".wav")
                os.close(fd)
                audio = AudioSegment.from_file(file_path)
                audio.export(temp_wav_path, format="wav")
                file_path = temp_wav_path
                logger.info(f"Converted audio saved to {temp_wav_path}")
            except Exception as e:
                logger.error(f"Error converting audio: {str(e)}")
                return f"Error converting audio: {str(e)}"

        # Preferred path: Whisper, when a key is configured. Any failure
        # here (including a missing openai package) falls through to Google.
        if api_key:
            try:
                logger.info("Attempting transcription with OpenAI Whisper API")
                import openai

                client = openai.OpenAI(api_key=api_key)

                with open(file_path, "rb") as audio_file:
                    transcript = client.audio.transcriptions.create(
                        model="whisper-1",
                        file=audio_file
                    )

                return transcript.text
            except Exception as e:
                logger.error(f"Error with Whisper API: {str(e)}")
                logger.info("Falling back to Google Speech Recognition")

        # Fallback path. The audio is only loaded into memory here, so a
        # successful Whisper call never pays for it.
        try:
            logger.info("Using Google Speech Recognition")
            recognizer = sr.Recognizer()
            with sr.AudioFile(file_path) as source:
                audio_data = recognizer.record(source)
            return recognizer.recognize_google(audio_data)
        except sr.UnknownValueError:
            return "Google Speech Recognition could not understand the audio"
        except sr.RequestError as e:
            return f"Could not request results from Google Speech Recognition service: {str(e)}"

    except ImportError:
        return "Required packages not installed. Please install pydub and SpeechRecognition with: pip install pydub SpeechRecognition"
    except Exception as e:
        logger.error(f"Error transcribing audio: {str(e)}")
        return f"Error transcribing audio: {str(e)}"
    finally:
        if temp_wav_path and os.path.exists(temp_wav_path):
            try:
                os.remove(temp_wav_path)
            except OSError as cleanup_error:
                # Best-effort cleanup; never mask the transcription result.
                logger.warning(f"Could not remove temporary file {temp_wav_path}: {cleanup_error}")
|
|
|
@tool
def extract_ingredients_from_audio(file_path: str, api_key: str = WHISPER_API_KEY) -> str:
    """
    Extract ingredients list from a recipe audio file.

    Args:
        file_path: Path to the audio file
        api_key: OpenAI API key for Whisper

    Returns:
        String containing the extracted ingredients
    """
    # Prefixes of every failure message transcribe_audio can return. These
    # are passed through unchanged instead of being parsed as a transcript.
    failure_prefixes = (
        "Error",
        "Could not",
        "Google Speech Recognition could not",
        "Required packages not installed",
    )
    section_end = r"(?:instructions|directions|method|steps|preparation|$)"
    patterns = [
        r"(?i)ingredients[:\s]+(.+?)" + section_end,
        r"(?i)you(?:'ll| will) need[:\s]+(.+?)" + section_end,
        r"(?i)what you(?:'ll| will) need[:\s]+(.+?)" + section_end,
        r"(?i)here's what you(?:'ll| will) need[:\s]+(.+?)" + section_end,
    ]
    # Trailing \b so single-letter units ("g", "l") only match as whole
    # words; without it "2 large eggs" matched the unit "l".
    measurements = r"(?i)(\d+(?:\s+\d+/\d+)?|\d+/\d+)\s*(?:cup|cups|tablespoon|tbsp|teaspoon|tsp|ounce|oz|pound|lb|gram|g|kg|ml|l|pinch|dash|handful|clove|cloves|bunch|can|package|pkg|bottle)\b"
    cooking_words = r"(?i)(instruction|direction|method|step|preparation|preheat|mix|stir|cook|bake)"

    try:
        # transcribe_audio is wrapped by @tool, so the module-level name is
        # a tool object; call the undecorated function via `.func` so both
        # arguments bind to the function parameters. The getattr fallback
        # keeps this working if the decorator is ever removed.
        transcribe = getattr(transcribe_audio, "func", transcribe_audio)
        transcript = transcribe(file_path, api_key)

        if transcript.startswith(failure_prefixes):
            return transcript

        # First choice: an explicitly announced ingredients section.
        ingredients_section = None
        for pattern in patterns:
            match = re.search(pattern, transcript, re.DOTALL)
            if match:
                ingredients_section = match.group(1).strip()
                break

        if not ingredients_section:
            # Fallback: collect 50 characters of context on either side of
            # anything that looks like "<quantity> <unit>".
            potential_ingredients = [
                transcript[max(0, m.start() - 50):min(len(transcript), m.end() + 50)]
                for m in re.finditer(measurements, transcript)
            ]
            if not potential_ingredients:
                return "Could not identify ingredients section in the audio. Please provide a clearer recording or manually list the ingredients."
            ingredients_section = "\n".join(potential_ingredients)

        # Drop blank lines and lines that read like cooking instructions.
        formatted_ingredients = [
            f"- {line}"
            for line in (raw.strip() for raw in ingredients_section.split("\n"))
            if line and not re.search(cooking_words, line)
        ]

        if not formatted_ingredients:
            # Everything was filtered out; return the raw section instead.
            return f"Extracted Ingredients:\n{ingredients_section}"

        return "Extracted Ingredients:\n" + "\n".join(formatted_ingredients)

    except Exception as e:
        logger.error(f"Error extracting ingredients: {str(e)}")
        return f"Error extracting ingredients: {str(e)}"
|
|
|
|
|
def get_tools(include_search: bool = True, tavily_api_key: Optional[str] = None,
              serpapi_api_key: Optional[str] = None, youtube_api_key: Optional[str] = None,
              whisper_api_key: Optional[str] = None):
    """
    Get all available tools for the agent.

    The GAIA API tools and the audio tools are always included. The
    YouTube tools and each search tool are added only when the matching
    key is supplied. A search tool that fails to import or initialise is
    logged and skipped rather than raised.

    Args:
        include_search: Whether to include search tools
        tavily_api_key: Tavily API key for search functionality
        serpapi_api_key: SerpAPI key for search functionality
        youtube_api_key: YouTube API key for video content access. Only
            used here to decide whether the YouTube tools are added; it is
            not passed to them.
        whisper_api_key: OpenAI Whisper API key for audio transcription.
            Currently unused by this function; accepted for interface
            compatibility.

    Returns:
        List of tools
    """
    tools = [
        fetch_question_details,
        fetch_context_file,
    ]

    # Audio tools are added unconditionally.
    tools.extend([transcribe_audio, extract_ingredients_from_audio])
    logger.info("Audio processing tools added to agent tools")

    if youtube_api_key:
        tools.extend([search_youtube, get_youtube_transcript])
        logger.info("YouTube tools added to agent tools")

    if not include_search:
        return tools

    if tavily_api_key:
        try:
            from langchain_community.tools.tavily_search import TavilySearchResults
            from langchain_community.utilities.tavily_search import TavilySearchAPIWrapper

            # The key belongs to the API wrapper, not the tool itself:
            # `api_key` is not a field of TavilySearchResults, so the
            # supplied key was never actually used.
            tavily_search = TavilySearchResults(
                max_results=7,
                api_wrapper=TavilySearchAPIWrapper(tavily_api_key=tavily_api_key),
            )
            tools.append(tavily_search)
            logger.info("Tavily search tool added to agent tools")
        except ImportError:
            logger.warning("Could not import TavilySearchResults. Tavily search will be disabled.")
        except Exception as e:
            logger.warning(f"Error initializing Tavily search tool: {e}")

    if serpapi_api_key:
        try:
            from langchain_community.utilities.serpapi import SerpAPIWrapper
            from langchain.tools import Tool

            search = SerpAPIWrapper(serpapi_api_key=serpapi_api_key)
            # NOTE(review): the tool name contains a space; some
            # function-calling backends restrict tool-name characters.
            # Confirm against the agent type in use before renaming.
            serpapi_tool = Tool(
                name="SerpAPI Search",
                description="A search engine. Useful for when you need to answer questions about current events or the current state of the world. Input should be a search query.",
                func=search.run
            )
            tools.append(serpapi_tool)
            logger.info("SerpAPI search tool added to agent tools")
        except ImportError:
            logger.warning("Could not import SerpAPIWrapper. SerpAPI search will be disabled.")
        except Exception as e:
            logger.warning(f"Error initializing SerpAPI search tool: {e}")

    return tools