# Source: Final_Assignment_Template / tools / youtube_video_tool.py
# Provenance: Hugging Face Space file (user FD900, revision 70f978a, 7.59 kB).
from smolagents import Tool
from tools.speech_recognition_tool import SpeechRecognitionTool
from transformers import HfInference
from io import BytesIO
import yt_dlp
import av
import subprocess
import requests
import base64
import tempfile
import re
import os
class YouTubeVideoTool(Tool):
    """smolagents Tool that answers a free-form query about a YouTube video.

    Pipeline: fetch metadata with yt-dlp, collect English captions (falling
    back to speech recognition when no caption track exists), sample frames
    with ffmpeg + PyAV, then walk the video in fixed-length chunks, asking a
    hosted text-generation model to refine a running answer chunk by chunk.
    """

    name = 'youtube_video'
    description = 'Extract information from YouTube video content using vision, audio, and captions.'
    inputs = {
        'url': {'type': 'string', 'description': 'YouTube video URL'},
        'query': {'type': 'string', 'description': 'Query about the video content'},
    }
    output_type = 'string'

    def __init__(
        self,
        endpoint_url: str,
        video_quality: int = 360,
        frames_interval: float = 2,
        chunk_duration: float = 2,
        speech_tool: SpeechRecognitionTool | None = None,
        debug: bool = False,
        **kwargs
    ):
        """Configure the tool.

        Args:
            endpoint_url: URL of the hosted text-generation endpoint.
            video_quality: Target video height in pixels (e.g. 360 for 360p).
            frames_interval: Minimum seconds between two sampled frames.
            chunk_duration: Length in seconds of each analysis window.
            speech_tool: Optional transcriber used when no caption track exists.
            debug: Stored but not otherwise used in this class.
            **kwargs: Forwarded to smolagents.Tool.__init__.
        """
        self.video_quality = video_quality
        self.frames_interval = frames_interval
        self.chunk_duration = chunk_duration
        self.speech_tool = speech_tool
        self.debug = debug
        # NOTE(review): `transformers` exports no `HfInference`; this was
        # probably meant to be `huggingface_hub.InferenceClient` (whose
        # `text_generation` method matches the call in `forward`). Confirm and
        # fix together with the top-level import — left unchanged here so this
        # edit stands alone.
        self.client = HfInference(endpoint_url=endpoint_url)
        super().__init__(**kwargs)

    def forward(self, url: str, query: str) -> str:
        """Answer `query` about the video at `url`.

        Iterates over consecutive chunks, letting the model refine its running
        answer; the literal sentinel 'I need to keep watching.' means the model
        found nothing yet and maps to an empty final answer.
        """
        full_answer = ''
        for chunk in self._split_video(url):
            prompt = self._compose_prompt(chunk, query, full_answer)
            resp = self.client.text_generation(
                prompt,
                model='mistralai/Mistral-7B-Instruct-v0.1',
                max_new_tokens=512,
            )
            full_answer = resp.generated_text.strip()
        return full_answer if full_answer != 'I need to keep watching.' else ''

    def _split_video(self, url):
        """Yield consecutive `chunk_duration`-second windows over the video."""
        video = self._process(url)
        dur = video['duration']
        start = 0
        while start < dur:
            end = min(start + self.chunk_duration, dur)
            yield self._chunk(video, start, end)
            start += self.chunk_duration

    def _chunk(self, video, start, end):
        """Slice captions and frames that overlap the [start, end] window."""
        # Caption overlap test: interval intersection, not containment.
        caps = [c for c in video['captions'] if c['start'] <= end and c['end'] >= start]
        frames = [f for f in video['frames'] if start <= f['timestamp'] <= end]
        return {
            'title': video['title'],
            'description': video['description'],
            'start': start,
            'end': end,
            'captions': '\n'.join(c['text'] for c in caps),
            'frames': frames,
        }

    def _compose_prompt(self, chunk, query, previous):
        """Build the text prompt for one chunk, threading in the prior answer.

        NOTE(review): `chunk['frames']` is never included here, so the sampled
        frames (and the "vision" claim in `description`) are currently unused —
        confirm whether frames were meant to be sent to a vision model.
        """
        parts = [
            f"VIDEO TITLE:\n{chunk['title']}",
            f"DESCRIPTION:\n{chunk['description']}",
            f"CAPTIONS:\n{chunk['captions']}",
        ]
        if previous:
            parts.append(f"PRIOR ANSWER:\n{previous}")
        parts.append(f"QUESTION:\n{query}")
        return "\n\n".join(parts)

    def _process(self, url):
        """Gather metadata, captions, and sampled frames for a video URL."""
        info = self._get_info(url)
        return {
            'id': info['id'],
            'title': info['title'],
            'description': info['description'],
            'duration': info['duration'],
            'captions': self._get_captions(info),
            'frames': self._get_frames(info),
        }

    def _get_info(self, url):
        """Fetch the yt-dlp info dict (no media download)."""
        ydl_opts = {
            'quiet': True,
            'skip_download': True,
            # Cap video height at the configured quality; audio separate.
            'format': f"bestvideo[height<={self.video_quality}]+bestaudio/best",
            'forceurl': True,
        }
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            return ydl.extract_info(url, download=False)

    def _get_captions(self, info):
        """Return caption dicts, transcribing the audio when no track exists.

        Prefers uploaded/automatic English captions; only if none are found
        AND a speech tool is configured does it download audio and transcribe.
        """
        lang = 'en'
        caps = self._extract_captions(lang, info.get('subtitles', {}), info.get('automatic_captions', {}))
        if not caps and self.speech_tool:
            audio_url = self._select_audio_format(info['formats'])
            audio = self._capture_audio(audio_url)
            # delete=False so the file survives the `with` for the speech tool;
            # removed explicitly in the finally below.
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
                f.write(audio.read())
                path = f.name
            try:
                txt = self.speech_tool(audio=path, with_time_markers=True)
                return self._parse_transcript(txt)
            finally:
                os.remove(path)
        return caps

    def _parse_transcript(self, raw):
        """Parse '[start]\\ntext\\n[end]' marker transcripts into caption dicts.

        The closing marker of one segment is the opening marker of the next, so
        it is matched with a lookahead; the previous pattern consumed it, which
        made `finditer` silently skip every other segment.
        """
        marker = re.compile(r'\[(\d+\.\d+)\]\n(.+?)\n(?=\[(\d+\.\d+)\])', re.DOTALL)
        return [
            {'start': float(s), 'end': float(e), 'text': text.strip()}
            for s, text, e in (m.groups() for m in marker.finditer(raw))
        ]

    def _extract_captions(self, lang, subs, auto):
        """Download and normalize the first srt/vtt caption track for `lang`.

        Uploaded subtitles win over automatic captions. Returns [] when no
        track in a supported format exists.
        """
        import pysrt, webvtt
        from io import StringIO

        def to_sec(t):
            # pysrt SubRipTime components -> seconds.
            return t.hours * 3600 + t.minutes * 60 + t.seconds + t.milliseconds / 1000

        def from_srt(srt_url):
            resp = requests.get(srt_url)
            return [{
                'start': to_sec(sub.start),
                'end': to_sec(sub.end),
                'text': sub.text.strip(),
            } for sub in pysrt.from_string(resp.text)]

        def from_vtt(vtt_url):
            def vtt_to_sec(ts):
                # 'HH:MM:SS.mmm' -> seconds.
                h, m, s = ts.split(':')
                s, ms = s.split('.')
                return int(h) * 3600 + int(m) * 60 + int(s) + int(ms) / 1000
            resp = requests.get(vtt_url)
            return [
                {'start': vtt_to_sec(c.start), 'end': vtt_to_sec(c.end), 'text': c.text.strip()}
                for c in webvtt.read_buffer(StringIO(resp.text))
            ]

        cap_track = subs.get(lang) or auto.get(lang) or []
        for track in cap_track:
            if track['ext'] == 'srt':
                return from_srt(track['url'])
            if track['ext'] == 'vtt':
                return from_vtt(track['url'])
        return []

    def _get_frames(self, info):
        """Sample frames from the selected video-only stream."""
        video_url = self._select_video_format(info['formats'])['url']
        return self._extract_frames(video_url)

    def _extract_frames(self, url):
        """Remux the stream to a temp file and sample one frame per interval.

        Returns a list of {'timestamp': float_seconds, 'image': PIL.Image}.
        Cleanup of the container and temp file is guaranteed via finally
        (the original leaked both when decoding raised).
        """
        with tempfile.NamedTemporaryFile(suffix='.mkv', delete=False) as tmp:
            path = tmp.name
        try:
            proc = subprocess.run(
                ['ffmpeg', '-y', '-i', url, '-f', 'matroska', path],
                stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
            )
            # Fail loudly and consistently with _capture_audio instead of
            # letting av.open choke on a truncated/empty file.
            if proc.returncode != 0:
                raise RuntimeError('Video capture failed')
            frames = []
            container = av.open(path)
            try:
                stream = container.streams.video[0]
                tb = stream.time_base
                next_ts = 0.0
                for frame in container.decode(stream):
                    if frame.pts is None:
                        continue
                    ts = float(frame.pts * tb)
                    if ts >= next_ts:
                        frames.append({'timestamp': ts, 'image': frame.to_image()})
                        next_ts += self.frames_interval
            finally:
                container.close()
        finally:
            os.remove(path)
        return frames

    def _select_video_format(self, formats):
        """Pick a video format at the configured height, or the tallest below it.

        The download spec in _get_info is `height<=quality`, so demanding an
        exact height match (as before) could raise even when a perfectly usable
        lower-resolution format exists.
        """
        candidates = [f for f in formats if f.get('vcodec') != 'none' and f.get('height')]
        for f in candidates:
            if f['height'] == self.video_quality:
                return f
        below = [f for f in candidates if f['height'] <= self.video_quality]
        if below:
            return max(below, key=lambda f: f['height'])
        raise ValueError('No matching video format found')

    def _select_audio_format(self, formats):
        """Return the URL of the best audio-only format (highest abr, m4a first).

        yt-dlp may report `abr` as None, which previously broke the sort key
        (`-None`); an empty candidate list now raises a clear ValueError
        instead of IndexError.
        """
        audio_formats = [f for f in formats if f.get('vcodec') == 'none' and f.get('acodec') != 'none']
        if not audio_formats:
            raise ValueError('No audio-only format found')
        audio_formats.sort(key=lambda f: (-(f.get('abr') or 0), f['ext'] != 'm4a'))
        return audio_formats[0]['url']

    def _capture_audio(self, audio_url):
        """Download/transcode the audio to 16 kHz mono PCM WAV in memory.

        Returns a BytesIO positioned at the start of the WAV data.
        Raises RuntimeError when ffmpeg exits non-zero.
        """
        cmd = [
            'ffmpeg', '-i', audio_url,
            '-f', 'wav', '-acodec', 'pcm_s16le', '-ac', '1', '-ar', '16000', '-'
        ]
        proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if proc.returncode != 0:
            raise RuntimeError('Audio capture failed')
        buf = BytesIO(proc.stdout)
        buf.seek(0)
        return buf