FD900 committed on
Commit
70f978a
·
verified ·
1 Parent(s): a6908e4

Update tools/youtube_video_tool.py

Browse files
Files changed (1) hide show
  1. tools/youtube_video_tool.py +204 -0
tools/youtube_video_tool.py CHANGED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import Tool
2
+ from tools.speech_recognition_tool import SpeechRecognitionTool
3
+ from transformers import HfInference
4
+ from io import BytesIO
5
+ import yt_dlp
6
+ import av
7
+ import subprocess
8
+ import requests
9
+ import base64
10
+ import tempfile
11
+ import re
12
+ import os
13
+
14
class YouTubeVideoTool(Tool):
    """Answer free-form questions about a YouTube video.

    The video's metadata is fetched with yt-dlp, the timeline is split into
    fixed-length chunks, and each chunk's title/description/captions are folded
    into a text prompt sent to a hosted LLM endpoint. The model's previous
    answer is carried from chunk to chunk so the final response reflects the
    whole video. Captions come from subtitle tracks when available, otherwise
    (if a speech tool was supplied) from speech recognition on the audio stream.
    """

    name = 'youtube_video'
    description = 'Extract information from YouTube video content using vision, audio, and captions.'

    inputs = {
        'url': {'type': 'string', 'description': 'YouTube video URL'},
        'query': {'type': 'string', 'description': 'Query about the video content'},
    }
    output_type = 'string'

    def __init__(
        self,
        endpoint_url: str,
        video_quality: int = 360,
        frames_interval: float = 2,
        chunk_duration: float = 2,
        speech_tool: SpeechRecognitionTool | None = None,
        debug: bool = False,
        **kwargs
    ):
        """Configure the tool.

        Args:
            endpoint_url: URL of the hosted text-generation inference endpoint.
            video_quality: Target video height in pixels (e.g. 360 for 360p).
            frames_interval: Seconds between sampled frames.
            chunk_duration: Length, in seconds, of each analyzed chunk.
            speech_tool: Optional fallback transcriber used when the video has
                no usable subtitle track.
            debug: Debug flag (currently only stored, never read here).
            **kwargs: Forwarded to ``smolagents.Tool``.
        """
        self.video_quality = video_quality
        self.frames_interval = frames_interval
        self.chunk_duration = chunk_duration
        self.speech_tool = speech_tool
        self.debug = debug
        # NOTE(review): `HfInference` is imported from `transformers` at the top
        # of this file, but that name is not part of the transformers API (the
        # usual client for a TGI endpoint is huggingface_hub.InferenceClient).
        # Confirm the import actually resolves before shipping.
        self.client = HfInference(endpoint_url=endpoint_url)
        super().__init__(**kwargs)

    def forward(self, url: str, query: str) -> str:
        """Answer `query` about the video at `url`.

        Iterates chunk by chunk, feeding the model its own prior answer so it
        can refine it. Returns '' when the model never advanced past its
        "keep watching" sentinel.
        """
        full_answer = ''
        for chunk in self._split_video(url):
            prompt = self._compose_prompt(chunk, query, full_answer)
            resp = self.client.text_generation(
                prompt,
                model='mistralai/Mistral-7B-Instruct-v0.1',
                max_new_tokens=512,
            )
            full_answer = resp.generated_text.strip()
        return full_answer if full_answer != 'I need to keep watching.' else ''

    def _split_video(self, url):
        """Yield consecutive `chunk_duration`-second chunk dicts for the video."""
        video = self._process(url)
        duration = video['duration']
        start = 0
        while start < duration:
            end = min(start + self.chunk_duration, duration)
            yield self._chunk(video, start, end)
            start += self.chunk_duration

    def _chunk(self, video, start, end):
        """Build one chunk dict covering the [start, end] window (seconds).

        Captions are kept when their interval overlaps the window; frames when
        their timestamp falls inside it.
        """
        caps = [c for c in video['captions'] if c['start'] <= end and c['end'] >= start]
        frames = [f for f in video['frames'] if start <= f['timestamp'] <= end]
        return {
            'title': video['title'],
            'description': video['description'],
            'start': start,
            'end': end,
            'captions': '\n'.join(c['text'] for c in caps),
            'frames': frames,
        }

    def _compose_prompt(self, chunk, query, previous):
        """Assemble the text prompt for one chunk.

        Note: sampled frames are NOT included — despite the tool description,
        only textual context reaches the model in this implementation.
        """
        parts = [
            f"VIDEO TITLE:\n{chunk['title']}",
            f"DESCRIPTION:\n{chunk['description']}",
            f"CAPTIONS:\n{chunk['captions']}",
        ]
        if previous:
            parts.append(f"PRIOR ANSWER:\n{previous}")
        parts.append(f"QUESTION:\n{query}")
        return "\n\n".join(parts)

    def _process(self, url):
        """Gather metadata, timed captions, and sampled frames for the video."""
        info = self._get_info(url)
        return {
            'id': info['id'],
            'title': info['title'],
            'description': info['description'],
            'duration': info['duration'],
            'captions': self._get_captions(info),
            'frames': self._get_frames(info),
        }

    def _get_info(self, url):
        """Return the yt-dlp info dict (with stream URLs) without downloading media."""
        ydl_opts = {
            'quiet': True,
            'skip_download': True,
            'format': f"bestvideo[height<={self.video_quality}]+bestaudio/best",
            'forceurl': True,
        }
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            return ydl.extract_info(url, download=False)

    def _get_captions(self, info):
        """Return caption dicts ({'start', 'end', 'text'}) for the video.

        Prefers subtitle/auto-caption tracks; when none parse and a speech tool
        is configured, transcribes the audio stream instead. The temporary WAV
        handed to the speech tool is always removed.
        """
        lang = 'en'
        caps = self._extract_captions(
            lang, info.get('subtitles', {}), info.get('automatic_captions', {})
        )
        if not caps and self.speech_tool:
            audio_url = self._select_audio_format(info['formats'])
            audio = self._capture_audio(audio_url)
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
                f.write(audio.read())
                f.flush()
                path = f.name
            try:
                txt = self.speech_tool(audio=path, with_time_markers=True)
                return self._parse_transcript(txt)
            finally:
                os.remove(path)
        return caps

    def _parse_transcript(self, raw):
        """Parse ``[start]\\ntext\\n[end]`` transcript markup into caption dicts.

        BUGFIX: the original pattern consumed each closing ``[t]`` marker; in an
        interleaved transcript that marker is also the next segment's opening
        marker, so every other segment was silently dropped. The trailing
        timestamp is now matched with a lookahead so consecutive segments all
        parse. (Assumes a single interleaved marker stream — confirm against the
        speech tool's actual output format.)
        """
        chunks = []
        pattern = r'\[(\d+\.\d+)\]\n(.+?)\n(?=\[(\d+\.\d+)\])'
        for match in re.finditer(pattern, raw, re.DOTALL):
            s, text, e = match.groups()
            chunks.append({'start': float(s), 'end': float(e), 'text': text.strip()})
        return chunks

    def _extract_captions(self, lang, subs, auto):
        """Download and parse the first SRT or VTT track for `lang`.

        Manual subtitles take priority over automatic captions. Returns [] when
        no parseable track exists.
        """
        import pysrt, webvtt
        from io import StringIO

        def to_sec(t):
            # pysrt SubRipTime components -> seconds
            return t.hours * 3600 + t.minutes * 60 + t.seconds + t.milliseconds / 1000

        def from_srt(srt_url):
            resp = requests.get(srt_url)
            return [{
                'start': to_sec(sub.start),
                'end': to_sec(sub.end),
                'text': sub.text.strip(),
            } for sub in pysrt.from_string(resp.text)]

        def from_vtt(vtt_url):
            def vtt_to_sec(ts):
                # WebVTT permits both HH:MM:SS.mmm and MM:SS.mmm timestamps;
                # the original assumed three components and crashed on two.
                parts = ts.split(':')
                if len(parts) == 2:
                    parts = ['0'] + parts
                h, m, s = parts
                s, ms = s.split('.')
                return int(h) * 3600 + int(m) * 60 + int(s) + int(ms) / 1000

            resp = requests.get(vtt_url)
            return [
                {'start': vtt_to_sec(c.start), 'end': vtt_to_sec(c.end), 'text': c.text.strip()}
                for c in webvtt.read_buffer(StringIO(resp.text))
            ]

        cap_track = subs.get(lang) or auto.get(lang) or []
        for track in cap_track:
            if track['ext'] == 'srt':
                return from_srt(track['url'])
            if track['ext'] == 'vtt':
                return from_vtt(track['url'])
        return []

    def _get_frames(self, info):
        """Sample timestamped frames (PIL images) from the selected video stream."""
        video_url = self._select_video_format(info['formats'])['url']
        return self._extract_frames(video_url)

    def _extract_frames(self, url):
        """Remux the stream to a local MKV and sample one frame per `frames_interval` seconds.

        Returns a list of {'timestamp': seconds, 'image': PIL.Image} dicts.
        BUGFIX: the temp file is now removed even when decoding raises, and an
        ffmpeg failure raises immediately instead of surfacing later as a
        confusing av.open error.
        """
        tmp = tempfile.NamedTemporaryFile(suffix='.mkv', delete=False)
        tmp.close()
        try:
            proc = subprocess.run(
                ['ffmpeg', '-y', '-i', url, '-f', 'matroska', tmp.name],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            if proc.returncode != 0:
                raise RuntimeError('Frame extraction failed: ffmpeg returned non-zero')
            container = av.open(tmp.name)
            try:
                stream = container.streams.video[0]
                time_base = stream.time_base
                frames = []
                next_sample_t = 0
                for frame in container.decode(stream):
                    # Frames without a presentation timestamp cannot be placed
                    # on the timeline; skip them.
                    if frame.pts is None:
                        continue
                    ts = float(frame.pts * time_base)
                    if ts >= next_sample_t:
                        frames.append({'timestamp': ts, 'image': frame.to_image()})
                        next_sample_t += self.frames_interval
                return frames
            finally:
                container.close()
        finally:
            os.remove(tmp.name)

    def _select_video_format(self, formats):
        """Pick a video format dict matching the configured quality.

        Prefers an exact height match (original behavior); otherwise falls back
        to the highest available height not exceeding `video_quality`, which is
        what the yt-dlp format selector (`height<=`) actually delivers.
        """
        candidates = [
            f for f in formats if f.get('vcodec') != 'none' and f.get('height')
        ]
        for f in candidates:
            if f['height'] == self.video_quality:
                return f
        below = [f for f in candidates if f['height'] <= self.video_quality]
        if below:
            return max(below, key=lambda f: f['height'])
        raise ValueError('No matching video format found')

    def _select_audio_format(self, formats):
        """Return the URL of the best audio-only format.

        Orders by bitrate (descending), breaking ties in favor of m4a.
        BUGFIX: yt-dlp format dicts often carry `abr: None`, which made the
        original `-f.get('abr', 0)` raise TypeError; `ext` may also be absent.
        """
        audio_formats = [
            f for f in formats
            if f.get('vcodec') == 'none' and f.get('acodec') != 'none'
        ]
        if not audio_formats:
            raise ValueError('No audio-only format found')
        audio_formats.sort(key=lambda f: (-(f.get('abr') or 0), f.get('ext') != 'm4a'))
        return audio_formats[0]['url']

    def _capture_audio(self, audio_url):
        """Pull the audio stream through ffmpeg as 16 kHz mono PCM WAV.

        Returns a BytesIO positioned at the start of the WAV data; raises
        RuntimeError when ffmpeg fails.
        """
        cmd = [
            'ffmpeg', '-i', audio_url,
            '-f', 'wav', '-acodec', 'pcm_s16le', '-ac', '1', '-ar', '16000', '-'
        ]
        proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if proc.returncode != 0:
            raise RuntimeError('Audio capture failed')
        buf = BytesIO(proc.stdout)
        buf.seek(0)
        return buf