# Ramja's picture
# Update tools.py
# fe76cc0 verified
from smolagents import DuckDuckGoSearchTool, Tool, VisitWebpageTool
from langchain_community.agent_toolkits.load_tools import load_tools
import os
from typing import List, Optional
from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound
from smolagents import tool
@tool
def youtube_transcript_search(
    video_id: str,
    query: str,
    top_k: Optional[int] = 5
) -> List[dict]:
    """
    Search the English transcript of a YouTube video for occurrences of a query string.

    A manually created English transcript is preferred; when none exists the
    auto-generated English transcript is used instead. Matching is a
    case-insensitive substring test against each transcript segment, and
    matches are reported in transcript order. When no transcript is available
    or nothing matches, a single placeholder entry with time 0.0 and an
    explanatory message is returned instead of an empty list.

    Args:
        video_id (str): The YouTube video ID (after 'v=' in the URL).
        query (str): The search term to look for in the transcript (case-insensitive).
        top_k (Optional[int]): Maximum number of matches to return (default: 5). Pass None to return every match.

    Returns:
        List[dict]: A list of up to top_k matches, each a dict with:
            - 'time': float start time in seconds
            - 'text_snippet': snippet of transcript containing the query
    """
    try:
        transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
        # find_manually_created_transcript raises NoTranscriptFound rather than
        # returning a falsy value, so the fallback to the auto-generated
        # transcript has to be done with an explicit except clause.
        try:
            transcript = transcript_list.find_manually_created_transcript(['en'])
        except NoTranscriptFound:
            transcript = transcript_list.find_generated_transcript(['en'])
        segments = transcript.fetch()
    except (TranscriptsDisabled, NoTranscriptFound) as e:
        return [{"time": 0.0, "text_snippet": f"No transcript found: {str(e)}"}]

    # Search query in segments
    query_lower = query.lower()
    hits = []
    for seg in segments:
        # Older library versions yield plain dicts; newer ones yield snippet
        # objects exposing the same fields as attributes. Accept both.
        if isinstance(seg, dict):
            text, start = seg["text"], seg["start"]
        else:
            text, start = seg.text, seg.start

        if query_lower in text.lower():
            hits.append({
                "time": start,
                "text_snippet": text
            })
            # top_k is Optional: None means "no limit" instead of crashing on
            # an int-to-None comparison.
            if top_k is not None and len(hits) >= top_k:
                break

    if not hits:
        return [{"time": 0.0, "text_snippet": "No matches found."}]
    return hits
# Module-level tool instances shared by the agents that import this file.

# Plain web search and page fetching, both provided directly by smolagents.
duck_search_tool = DuckDuckGoSearchTool()
visit_web_page_tool = VisitWebpageTool()

# Google search goes through LangChain's "serpapi" tool: load it, take the
# first (only requested) tool, and wrap it so smolagents can call it.
_serpapi_langchain_tools = load_tools(["serpapi"])
google_search_tool = Tool.from_langchain(_serpapi_langchain_tools[0])