Ramja commited on
Commit
fe76cc0
·
verified ·
1 Parent(s): f0a86d4

Update tools.py

Browse files

Add a tool for youtube transcripts

Files changed (1) hide show
  1. tools.py +48 -0
tools.py CHANGED
@@ -1,7 +1,55 @@
1
  from smolagents import DuckDuckGoSearchTool, Tool, VisitWebpageTool
2
  from langchain_community.agent_toolkits.load_tools import load_tools
3
  import os
 
 
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  # Initialize the search tools
7
 
 
1
  from smolagents import DuckDuckGoSearchTool, Tool, VisitWebpageTool
2
  from langchain_community.agent_toolkits.load_tools import load_tools
3
  import os
4
+ from typing import List, Optional
5
+ from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound
6
+ from smolagents import tool
7
 
8
+ @tool
9
+ def youtube_transcript_search(
10
+ video_id: str,
11
+ query: str,
12
+ top_k: Optional[int] = 5
13
+ ) -> List[dict]:
14
+ """
15
+ Search a YouTube transcript for occurrences of a query string.
16
+
17
+ Args:
18
+ video_id (str): The YouTube video ID (after 'v=' in the URL).
19
+ query (str): The search term to look for in the transcript (case-insensitive).
20
+ top_k (Optional[int]): Maximum number of matches to return (default: 5).
21
+
22
+ Returns:
23
+ List[dict]: A list of up to top_k matches, each a dict with:
24
+ - 'time': float start time in seconds
25
+ - 'text_snippet': snippet of transcript containing the query
26
+ """
27
+ try:
28
+ # Fetch transcript (choose auto-generated or manually created)
29
+ transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
30
+ # prefer generated if manual unavailable
31
+ transcript = transcript_list.find_manually_created_transcript(['en']) or transcript_list.find_generated_transcript(['en'])
32
+ segments = transcript.fetch()
33
+
34
+ except (TranscriptsDisabled, NoTranscriptFound) as e:
35
+ return [{"time": 0.0, "text_snippet": f"No transcript found: {str(e)}"}]
36
+
37
+ # Search query in segments
38
+ query_lower = query.lower()
39
+ hits = []
40
+ for seg in segments:
41
+ if query_lower in seg["text"].lower():
42
+ hits.append({
43
+ "time": seg["start"],
44
+ "text_snippet": seg["text"]
45
+ })
46
+ if len(hits) >= top_k:
47
+ break
48
+
49
+ if not hits:
50
+ return [{"time": 0.0, "text_snippet": "No matches found."}]
51
+
52
+ return hits
53
 
54
  # Initialize the search tools
55